// SPDX-License-Identifier: GPL-2.0
/*
 * GPU: VRAM on BAR
 *
 * (C) 2024.03.14 BuddyZhang1 <buddy.zhang@aliyun.com>
 */
#include <linux/init.h>
#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/miscdevice.h>
#include <linux/fs.h>
#include <linux/slab.h>
#include <linux/uaccess.h>
#include <linux/pci.h>
#include <linux/miscdevice.h>
#include <linux/ioport.h>
#include <linux/jiffies.h>
#include <linux/sched/mm.h>
#include <linux/mmu_notifier.h>
#include <linux/hmm.h>

/*
 * Argument block of the HMM_READ ioctl, copied in from userspace.
 * The layout is shared with userspace, so the field order must not change.
 * NOTE(review): the ioctl handler in this file only consumes @addr;
 * @ptr and @npages are carried but not read.
 */
struct BiscuitOS_hmm_cmd {
	unsigned long addr, ptr, npages;
};

/* Name used for both the misc device node and the PCI driver. */
#define DEV_NAME		"BiscuitOS-GPU-VRAM-BAR"
/* BAR index that is mapped and treated as the VRAM window. */
#define MMIO_BAR		0x00
/* Fault one page of the caller's address space through HMM. */
#define HMM_READ		_IOWR('H', 0x00, struct BiscuitOS_hmm_cmd)

/*
 * Per-device driver state. Only one instance is supported; it is
 * reachable through the file-scope pointer bpdev below.
 */
struct BiscuitOS_pci_device {
	struct pci_dev *pdev;

	/* Kernel virtual mapping of BAR MMIO_BAR (from pci_iomap()). */
	void __iomem *mmio;

	/* Physical start, end and length of that BAR ("VRAM"). */
	phys_addr_t mmio_base, mmio_end, mmio_len;

	/*
	 * HMM: device-private pagemap plus a singly linked list of its
	 * pages, chained through page->zone_device_data.
	 */
	struct dev_pagemap pagemap;
	struct page *free_pages;

	/* Interval notifier registered on the opener's mm. */
	struct mmu_interval_notifier notifier;
};

static struct BiscuitOS_pci_device *bpdev;

static vm_fault_t BiscuitOS_HMM_fault(struct vm_fault *vmf)
{
	printk("FAULT\n");
	return 0;
}

/*
 * dev_pagemap_ops.page_free callback: intentionally does nothing.
 * NOTE(review): the freed page is not pushed back onto
 * bpdev->free_pages; do that (under a lock) once pages are handed out.
 */
static void BiscuitOS_HMM_free(struct page *page)
{
	(void)page;
}

/*
 * mmu_interval_notifier_ops.invalidate callback.
 *
 * Records the new invalidation sequence number so that a concurrent
 * hmm_range_fault() (which returns -EBUSY) or mmu_interval_read_retry()
 * notices that the range changed. Without mmu_interval_set_seq() neither
 * check can ever detect an invalidation.
 *
 * Invalidations raised by this device's own migrations are skipped,
 * because the migration path is responsible for them.
 *
 * Always returns true: this callback never blocks, so it is valid
 * for non-blockable ranges too.
 *
 * NOTE(review): the mmu_interval_set_seq() contract asks for a driver
 * lock shared with the mmu_interval_read_retry() caller; this driver
 * has none yet.
 */
static bool
BiscuitOS_mmu_interval_invalidate(struct mmu_interval_notifier *mni,
	const struct mmu_notifier_range *range, unsigned long cur_seq)
{
	struct BiscuitOS_pci_device *bdev =
		container_of(mni, struct BiscuitOS_pci_device, notifier);

	if (range->event == MMU_NOTIFY_MIGRATE &&
		range->owner == bdev)
		return true;

	mmu_interval_set_seq(mni, cur_seq);

	return true;
}

/* Interval-notifier callbacks attached to bpdev->notifier at open time. */
static const struct mmu_interval_notifier_ops BiscuitOS_mmu_ops = {
	.invalidate	= BiscuitOS_mmu_interval_invalidate,
};

static int BiscuitOS_misc_open(struct inode *inode, struct file *filp)
{
	int r;

	r = mmu_interval_notifier_insert(&bpdev->notifier, current->mm,
			0, ULONG_MAX & PAGE_MASK, &BiscuitOS_mmu_ops);

	return 0;
}

static int BiscuitOS_hmm_read(struct BiscuitOS_hmm_cmd *cmd)
{
	unsigned long pfns[64];
	struct hmm_range range = {
		.notifier = &bpdev->notifier,
		.hmm_pfns = pfns,
		.pfn_flags_mask = 0,
		.default_flags = HMM_PFN_REQ_FAULT | HMM_PFN_REQ_WRITE,
		.dev_private_owner = bpdev,
		.start = cmd->addr,
		.end = cmd->addr + PAGE_SIZE,
	};

	range.notifier_seq = mmu_interval_read_begin(range.notifier);
	hmm_range_fault(&range);
	mmu_interval_read_retry(range.notifier, range.notifier_seq);

	return 0;
}

static long BiscuitOS_misc_ioctl(struct file *filp,
			unsigned int cmd, unsigned long arg)
{
	struct BiscuitOS_hmm_cmd hmm_cmd;

	if (copy_from_user(&hmm_cmd, (void __user *)arg, sizeof(hmm_cmd)))
		return -EFAULT;

	switch (cmd) {
	case HMM_READ:
		BiscuitOS_hmm_read(&hmm_cmd);
		break;
	}

	return 0;
}

static struct file_operations BiscuitOS_misc_fops = {
	.owner		= THIS_MODULE,
	.open		= BiscuitOS_misc_open,
	.unlocked_ioctl = BiscuitOS_misc_ioctl,
};

/* Userspace entry point: a misc device with a dynamically assigned minor. */
static struct miscdevice BiscuitOS_miscdev = {
	.name	= DEV_NAME,
	.fops	= &BiscuitOS_misc_fops,
	.minor	= MISC_DYNAMIC_MINOR,
};

/* Callbacks for the MEMORY_DEVICE_PRIVATE pagemap set up in BiscuitOS_HMM_init(). */
static const struct dev_pagemap_ops BiscuitOS_HMM_ops = {
	.page_free	= BiscuitOS_HMM_free,
	.migrate_to_ram	= BiscuitOS_HMM_fault,
};

static int BiscuitOS_HMM_init(struct BiscuitOS_pci_device *bpdev)
{
	unsigned long pfn, pfn_start, pfn_last;
	struct resource *res = NULL;
	void *ptr;

	/* ENABLE: CONFIG_GET_FREE_REGION & CONFIG_DEVICE_PRIVATE */
	res = request_free_mem_region(&iomem_resource, bpdev->mmio_len,
			"BiscuitOS HMM");

	bpdev->pagemap.range.start = res->start;
	bpdev->pagemap.range.end   = res->end;
	bpdev->pagemap.type        = MEMORY_DEVICE_PRIVATE;
	bpdev->pagemap.nr_range    = 1;
	bpdev->pagemap.ops         = &BiscuitOS_HMM_ops,
	bpdev->pagemap.owner       = bpdev;

	ptr = memremap_pages(&bpdev->pagemap, numa_node_id());
	if (IS_ERR_OR_NULL(ptr)) {
		printk("HMM MEMREMAP ERROR.\n");
		return PTR_ERR(ptr);
	}

	pfn_start = bpdev->pagemap.range.start >> PAGE_SHIFT;
	pfn_last  = pfn_start +
			(range_len(&bpdev->pagemap.range) >> PAGE_SHIFT);

	for (pfn = pfn_start; pfn < pfn_last; pfn++) {
		struct page *page = pfn_to_page(pfn);

		page->zone_device_data = bpdev->free_pages;
		bpdev->free_pages = page;
	}	

	return 0;
}

/*
 * PCI probe: allocate the driver state and enable the device. Claim and
 * map BAR MMIO_BAR, set up the HMM pagemap, and only then expose the
 * misc device to userspace.
 *
 * Every step's result is checked. On failure, everything acquired so far
 * is undone in reverse order and bpdev is reset to NULL.
 */
static int BiscuitOS_pci_probe(struct pci_dev *pdev,
				const struct pci_device_id *id)
{
	int r;

	bpdev = kzalloc(sizeof(*bpdev), GFP_KERNEL);
	if (!bpdev) {
		r = -ENOMEM;
		printk("%s ERROR: BiscuitOS PCI allocate failed.\n", DEV_NAME);
		goto err_alloc;
	}
	bpdev->pdev = pdev;

	/* ENABLE PCI DEVICE */
	r = pci_enable_device(pdev);
	if (r < 0) {
		printk("%s ERROR: PCI Device Enable failed.\n", DEV_NAME);
		goto err_enable_pci;
	}

	/* CLAIM THE BAR: released by pci_release_region() in remove */
	r = pci_request_region(pdev, MMIO_BAR, DEV_NAME);
	if (r < 0) {
		printk("%s ERROR: Request MMIO Region failed.\n", DEV_NAME);
		goto err_request;
	}

	/* REMAPPING MMIO INTO KERNEL SPACE */
	bpdev->mmio = pci_iomap(pdev, MMIO_BAR, pci_resource_len(pdev, MMIO_BAR));
	if (!bpdev->mmio) {
		r = -EBUSY;
		printk("%s ERROR: Remapping MMIO Failed\n", DEV_NAME);
		goto err_iomap;
	}

	/* SET MASTER */
	pci_set_master(pdev);

	/* VRAM */
	bpdev->mmio_base = pci_resource_start(pdev, MMIO_BAR);
	bpdev->mmio_end  = pci_resource_end(pdev, MMIO_BAR);
	bpdev->mmio_len  = pci_resource_len(pdev, MMIO_BAR);

	/* HMM: must be ready before userspace can reach the device */
	r = BiscuitOS_HMM_init(bpdev);
	if (r < 0) {
		printk("%s ERROR: HMM init failed.\n", DEV_NAME);
		goto err_hmm;
	}

	/* USERSPACE INTERFACE */
	r = misc_register(&BiscuitOS_miscdev);
	if (r < 0) {
		printk("%s ERROR: Misc device register failed.\n", DEV_NAME);
		goto err_misc;
	}

	printk("%s Success Register PCIe Device.\n", DEV_NAME);

	return 0;

err_misc:
	memunmap_pages(&bpdev->pagemap);
	release_mem_region(bpdev->pagemap.range.start,
			   range_len(&bpdev->pagemap.range));
err_hmm:
	pci_clear_master(pdev);
	pci_iounmap(pdev, bpdev->mmio);
err_iomap:
	pci_release_region(pdev, MMIO_BAR);
err_request:
	pci_disable_device(pdev);
err_enable_pci:
	kfree(bpdev);
	bpdev = NULL;
err_alloc:
	return r;
}

static void BiscuitOS_pci_remove(struct pci_dev *pdev)
{
	pci_iounmap(pdev, bpdev->mmio);
	pci_release_region(pdev, MMIO_BAR);
	pci_disable_device(pdev);
	kfree(bpdev);
	bpdev = NULL;
}

/*
 * Devices handled by this driver (vendor 0x1010, device 0x1991).
 * The PCI core walks this table until it reaches an all-zero entry, so
 * the terminator is mandatory. MODULE_DEVICE_TABLE exports it for
 * module autoloading.
 */
static const struct pci_device_id BiscuitOS_pci_ids[] = {
	{ PCI_DEVICE(0x1010, 0x1991), },
	{ 0, }
};
MODULE_DEVICE_TABLE(pci, BiscuitOS_pci_ids);

/* PCI driver binding the probe/remove callbacks to BiscuitOS_pci_ids. */
static struct pci_driver BiscuitOS_PCIe_driver = {
	.probe		= BiscuitOS_pci_probe,
	.remove		= BiscuitOS_pci_remove,
	.name		= DEV_NAME,
	.id_table	= BiscuitOS_pci_ids,
};

/* Module load: register the PCI driver with the PCI core. */
static int __init BiscuitOS_init(void)
{
	int ret;

	ret = pci_register_driver(&BiscuitOS_PCIe_driver);
	return ret;
}
module_init(BiscuitOS_init);

/* Module unload: unregister the PCI driver (unbinds any device). */
static void __exit BiscuitOS_exit(void)
{
	pci_unregister_driver(&BiscuitOS_PCIe_driver);
}
module_exit(BiscuitOS_exit);

MODULE_DESCRIPTION("GPU");
MODULE_AUTHOR("BiscuitOS <buddy.zhang@aliyun.com>");
MODULE_LICENSE("GPL");
